import numpy as np
from experiment_setting import Experiment
from util import feasibility_check, optimal_solution, glrt, beta
from sampling_rule import EA, DualSR, AFSR, AFOSR, SEQ_OFSR
from LinearModel import OLSEstimator

class Agent:
    """Sequential sampling agent over a grid of (context, design) alternatives.

    The agent draws samples from an ``Experiment``. It keeps running sample
    means of the two outputs returned by ``Experiment.generate_samples`` for
    every (context, design) cell. Once enough samples exist, it fits an
    ``OLSEstimator`` and uses the fitted predictions to estimate feasibility
    and the optimal design for each context.

    ``sample`` delegates the choice of the next cell to one of the rules
    imported from ``sampling_rule``. ``stop`` applies the stopping test.

    ``reset`` must be called before ``sample``/``step``/``stop``, because it
    is the only place where the estimation state is created.
    """

    # Identifiers accepted by ``sample``. The trailing comments in ``sample``
    # name the algorithm that each identifier corresponds to.
    SAMPLING_RULES = ("EA", "DualSR", "AFSR", "AFOSR", "SEQ_OFSR")

    def __init__(self, n0, delta, sampling_rule: str):
        """Bind the agent to a fresh ``Experiment`` and copy its parameters.

        Args:
            n0: Warm-up multiplier. ``stop`` never fires before ``d * n0``
                total samples. It is also forwarded to every rule except EA.
            delta: Confidence parameter forwarded to ``beta`` in ``stop``.
            sampling_rule: Name of the sampling rule; one of
                ``SAMPLING_RULES``.
        """
        self.sampling_rule = sampling_rule
        self.EXP = Experiment()
        self.dist = self.EXP.dist
        self.M = self.EXP.M  # number of contexts (axis 0 of the state arrays)
        self.K = self.EXP.K  # number of designs (axis 1 of the state arrays)
        # Sample-count threshold before the linear model is fitted;
        # presumably the number of regression parameters - TODO confirm.
        self.d = self.EXP.d
        self.n0 = n0
        self.delta = delta
        self.f = self.EXP.f
        self.g = self.EXP.g
        self.b = self.EXP.b
        self.variance = self.EXP.variance
        self.std = self.EXP.std
        self.opt_solution = self.EXP.opt_solution

    def reset(self) -> None:
        """(Re)initialise all estimation state for a new run."""
        # Per-cell running means of the raw observations, shape (M, K).
        self.F_esti = np.zeros((self.M, self.K))
        self.G_esti = np.zeros((self.M, self.K))
        # Linear-model predictions, shape (M, K). They stay zero until at
        # least ``d`` samples have been taken (see ``step``).
        self.f_esti = np.zeros((self.M, self.K))
        self.g_esti = np.zeros((self.M, self.K))
        self.feasibility_esti = np.ones((self.M, self.K))
        # Estimated optimal design index per context.
        self.opt_solution_esti = np.zeros(self.M, dtype=int)
        # Number of samples drawn from each (context, design) cell.
        self.alternative_count = np.zeros((self.M, self.K))
        self.lin_Estimators = OLSEstimator(self.EXP)
        self.Z_mat, self.Temp = self.lin_Estimators.compute_Z()
        # Sampling weights used by DualSR, initially uniform over all M*K cells.
        self.lambda_ = np.ones(self.K * self.M) / (self.K * self.M)
        # True when the estimated optimal solution changed on the last step.
        self.flag = False

    def sample(self):
        """Choose the next (context, design) cell to sample.

        Returns:
            ``(next_context, next_design)`` as produced by the configured rule.

        Raises:
            ValueError: If ``self.sampling_rule`` is not in ``SAMPLING_RULES``.
                Failing loudly matters here. A ``(None, None)`` pair passed to
                ``step`` is NumPy new-axis indexing, which would update every
                cell of the count and mean arrays instead of raising.
        """
        if self.sampling_rule not in self.SAMPLING_RULES:
            raise ValueError(
                f"unknown sampling_rule {self.sampling_rule!r}; "
                f"expected one of {self.SAMPLING_RULES}"
            )

        if self.sampling_rule == "EA":  # USR Algorithm
            next_context, next_design = EA(self.alternative_count, self.EXP.design_indices)
        elif self.sampling_rule == "DualSR":  # DSR Algorithm
            next_context, next_design, lambda_ = DualSR(
                self.K, self.M, self.d, self.b, self.f_esti, self.g_esti,
                self.opt_solution_esti, self.feasibility_esti,
                self.EXP.phi, self.Z_mat, self.Temp, self.n0,
                self.alternative_count, self.lambda_, self.EXP.design_indices,
                self.EXP.design_var, self.flag)
            # DualSR may return no new weights; keep the previous ones then.
            if lambda_ is not None:
                self.lambda_ = lambda_
        elif self.sampling_rule == "AFSR":  # BCSR Algorithm
            next_context, next_design = AFSR(
                self.EXP, self.alternative_count, self.n0, self.d, self.f_esti,
                self.EXP.design_indices)
        elif self.sampling_rule == "AFOSR":  # GFSR Algorithm
            next_context, next_design = AFOSR(
                self.EXP, self.K, self.M, self.d, self.b, self.f_esti, self.g_esti,
                self.alternative_count, self.EXP.phi, self.Temp, self.n0,
                self.EXP.design_indices)
        else:  # "SEQ_OFSR", GOSR Algorithm (only remaining valid rule)
            next_context, next_design = SEQ_OFSR(
                self.EXP, self.K, self.M, self.d, self.f_esti, self.EXP.phi,
                self.Temp, self.n0, self.alternative_count,
                self.EXP.design_indices)
        return next_context, next_design

    def step(self, context, design):
        """Draw one sample at ``(context, design)`` and refresh all estimates.

        ``context`` indexes axis 0 (size M) and ``design`` indexes axis 1
        (size K) of the state arrays.
        """
        self.alternative_count[context, design] += 1
        F, G = self.EXP.generate_samples(context, design)
        # Incremental running mean: m_n = m_{n-1} + (x - m_{n-1}) / n.
        self.F_esti[context, design] += (F - self.F_esti[context, design]) / self.alternative_count[context, design]
        self.G_esti[context, design] += (G - self.G_esti[context, design]) / self.alternative_count[context, design]

        # Fit the linear model only once at least ``d`` samples exist in total.
        if np.sum(self.alternative_count) >= self.d:
            self.lin_Estimators.updateModel(self.F_esti, self.G_esti)
            self.f_esti, self.g_esti = self.lin_Estimators.predict()

        self.feasibility_esti = feasibility_check(self.M, self.K, self.g_esti, self.b)
        old_opt_solution = self.opt_solution_esti
        self.opt_solution_esti = optimal_solution(self.M, self.f_esti, self.feasibility_esti)
        # Record whether the estimated optimum moved; this is forwarded to DualSR.
        self.flag = not np.array_equal(old_opt_solution, self.opt_solution_esti)

    def select(self):
        """Return the current estimated optimal solution (one entry per context)."""
        return self.opt_solution_esti

    def stop(self) -> bool:
        """Return True when the stopping rule is satisfied.

        No stop is possible before ``d * n0`` total samples. After that, let
        ``t`` be the total sample count. The rule stops when ``t`` times the
        statistic returned by ``glrt`` exceeds ``beta(t, delta)``.
        """
        t = np.sum(self.alternative_count)
        if t < self.d * self.n0:
            return False
        reg_variance, diff_reg_variance = self.lin_Estimators.compute_regression_variance(
            self.alternative_count, self.opt_solution_esti)
        val = glrt(self.M, self.K, self.b, self.f_esti, self.g_esti, self.opt_solution_esti,
                   self.feasibility_esti, reg_variance, diff_reg_variance)
        return bool(t * val > beta(t, self.delta))














